Problematic projections

Author

Juho Timonen

Setup

Show code
library(rstan)
Warning: package 'rstan' was built under R version 4.1.2
Loading required package: StanHeaders
Loading required package: ggplot2
Warning: package 'ggplot2' was built under R version 4.1.2
rstan (Version 2.21.5, GitRev: 2e1f913d3ca3)
For execution on a local, multicore CPU with excess RAM we recommend calling
options(mc.cores = parallel::detectCores()).
To avoid recompilation of unchanged Stan programs, we recommend calling
rstan_options(auto_write = TRUE)
Show code
library(posterior)
Warning: package 'posterior' was built under R version 4.1.2
This is posterior version 1.2.1

Attaching package: 'posterior'
The following objects are masked from 'package:rstan':

    ess_bulk, ess_tail
The following objects are masked from 'package:stats':

    mad, sd, var
Show code
library(ggplot2)
library(mgcv)
Loading required package: nlme
This is mgcv 1.8-35. For overview type 'help("mgcv-package")'.
Show code
library(dimreduce)
This is dimreduce version 0.2.2
Show code
source("R/rstan.R")
source("R/search.R")
source("R/project.R")
source("R/project_old.R") # for project_sigma()
source("R/simulate.R")
Warning: package 'MASS' was built under R version 4.1.2
Show code
source("R/plotting.R")
source("R/selection.R")
source("R/hs_smooth.R")
source("R/relevances.R")
options(projpred.extra_verbose = TRUE)

CHAINS <- 1
ITER <- 600
CN <- list(adapt_delta = 0.9)

# Simulation setup
rho <- 0.7
sigma <- 3
rel_true <- c(1, 1, 0, 0, 1, 0, 0, 0) # true relevances
n <- 600
D <- length(rel_true)
set.seed(7899) # fix random seed

Data simulation

Show code
# Run sim and create data setup
dat <- simulate(n, rho, rel_true, sigma)
n <- length(dat$y)
splt <- create_split(n, 0.5)
ds <- list(dat = dat, split = splt)

Fitting reference model

Show code
# Model setup
# Create the Stan model
mod <- rstan::stan_model("stan/model_sample.stan")
ms <- list(
  stan_model = mod,
  B = 24,
  scale_bf = 1.5
)

dat_df <- create_data_frame(ds, test = FALSE)


fmf1 <- fit_full_model(ds, ms,
  chains = CHAINS, iter = ITER, control = CN,
  thin = 10,
  seed = 507949442
)
Creating input.
 * printing L
[1] 5.145850 4.941741 5.021478 4.307805 4.232649 4.840409 4.988403 4.619531
 * printing L
[1] 5.145850 4.941741 5.021478 4.307805 4.232649 4.840409 4.988403 4.619531
Fitting full model.
Warning: The largest R-hat is 1.45, indicating chains have not mixed.
Running the chains for more iterations may help. See
https://mc-stan.org/misc/warnings.html#r-hat
Warning: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
Running the chains for more iterations may help. See
https://mc-stan.org/misc/warnings.html#bulk-ess
Warning: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
Running the chains for more iterations may help. See
https://mc-stan.org/misc/warnings.html#tail-ess
Sampling done.
Show code
# Use SPCA reference model
# fmf2 <- fit_full_model(ds, ms,
#  chains = CHAINS, iter = ITER, control = CN,
#  thin = 10, nctot = 5
# )

Ranking the variables

Show code
# Rank and print the ordering: `path` holds the indices of `cv1` sorted from
# the largest value to the smallest.
# comp_vars() is defined in the sourced project files; presumably it returns
# one score per variable -- confirm against its definition.
cv1 <- comp_vars(fmf1$fit, fmf1$J)
# Spell out TRUE: `T` is an ordinary variable that can be reassigned.
path <- sort(cv1, index.return = TRUE, decreasing = TRUE)$ix
print(path)
[1] 1 5 2 8 3 7 6 4

Visualizing the reference model posterior

Show code
pa <- plot(fmf1$fit, pars = c("alpha")) + ggtitle("Magnitudes")
ci_level: 0.8 (80% intervals)
outer_level: 0.95 (95% intervals)
Show code
pb <- plot(fmf1$fit, pars = c("ell")) + ggtitle("Lengthscales")
ci_level: 0.8 (80% intervals)
outer_level: 0.95 (95% intervals)
Show code
ppost <- ggpubr::ggarrange(pa, pb, nrow = 1, ncol = 2)
ppost

Performing the projections

Show code
# Run selection with projection predictive method
options(pp.threshold = Inf) # no stopping condition
options(min.sp = 1e-1) # lower bound for smoothing parameter

# Use HS GP basis
options(bs.type = "hs")
sel_pp_hs4 <- selection_pp(fmf1, path)
Performing forward search.
Running a ProjectionForwardSearch with option = gam 
Stopping threshold for explained variance = Inf 
Projecting to submodel [] 
Step 1/8.
 * using pre-defined search path
Projecting to submodel [1] 
 * 1 - kl_current/kl_empty = 0.5188
Step 2/8.
 * using pre-defined search path
Projecting to submodel [1 5] 
 * 1 - kl_current/kl_empty = 0.8724
Step 3/8.
 * using pre-defined search path
Projecting to submodel [1 5 2] 
 * 1 - kl_current/kl_empty = 0.9732
Step 4/8.
 * using pre-defined search path
Projecting to submodel [1 5 2 8] 
 * 1 - kl_current/kl_empty = 0.9896
Step 5/8.
 * using pre-defined search path
Projecting to submodel [1 5 2 8 3] 
 * 1 - kl_current/kl_empty = 0.9942
Step 6/8.
 * using pre-defined search path
Projecting to submodel [1 5 2 8 3 7] 
 * 1 - kl_current/kl_empty = 0.9966
Step 7/8.
 * using pre-defined search path
Projecting to submodel [1 5 2 8 3 7 6] 
 * 1 - kl_current/kl_empty = 0.9988
Step 8/8.
 * using pre-defined search path
Projecting to submodel [1 5 2 8 3 7 6 4] 
 * 1 - kl_current/kl_empty = 1
Search done.
Show code
# Use thin plate spline basis
options(bs.type = "tp")
sel_pp_tp <- selection_pp(fmf1, path)
Performing forward search.
Running a ProjectionForwardSearch with option = gam 
Stopping threshold for explained variance = Inf 
Projecting to submodel [] 
Step 1/8.
 * using pre-defined search path
Projecting to submodel [1] 
 * 1 - kl_current/kl_empty = 0.5184
Step 2/8.
 * using pre-defined search path
Projecting to submodel [1 5] 
 * 1 - kl_current/kl_empty = 0.8707
Step 3/8.
 * using pre-defined search path
Projecting to submodel [1 5 2] 
 * 1 - kl_current/kl_empty = 0.9699
Step 4/8.
 * using pre-defined search path
Projecting to submodel [1 5 2 8] 
 * 1 - kl_current/kl_empty = 0.9862
Step 5/8.
 * using pre-defined search path
Projecting to submodel [1 5 2 8 3] 
 * 1 - kl_current/kl_empty = 0.9903
Step 6/8.
 * using pre-defined search path
Projecting to submodel [1 5 2 8 3 7] 
 * 1 - kl_current/kl_empty = 0.9927
Step 7/8.
 * using pre-defined search path
Projecting to submodel [1 5 2 8 3 7 6] 
 * 1 - kl_current/kl_empty = 0.9951
Step 8/8.
 * using pre-defined search path
Projecting to submodel [1 5 2 8 3 7 6 4] 
 * 1 - kl_current/kl_empty = 0.9963
Search done.
Show code
sels <- list(sel_pp_hs4, sel_pp_tp)
names(sels) <- c("hs4", "tp")
# Plot the MLPD recorded in each selection's search history against the
# number of variables, drawing one line per selection.
#
# Args:
#   sels: named list of selection results; every element must provide
#         `$search$history$mlpd_test` and `$search$history$mlpd_train`,
#         all of the same length.
#   test: if TRUE plot the test MLPD, otherwise the train MLPD.
#
# Returns: a ggplot object.
plot_mlpd <- function(sels, test = TRUE) {
  get_mlpd_test <- function(x) {
    x$search$history$mlpd_test
  }
  get_mlpd_train <- function(x) {
    x$search$history$mlpd_train
  }
  if (test) {
    mlpd <- data.frame(sapply(sels, get_mlpd_test))
    ylabb <- "Test MLPD"
  } else {
    mlpd <- data.frame(sapply(sels, get_mlpd_train))
    ylabb <- "Train MLPD"
  }

  # Number the rows from 0 upwards. Deriving the count from the data (rather
  # than the former fixed 0:8) keeps the function usable for search paths of
  # any length while giving the same labels for a 9-row history.
  mlpd$num_vars <- seq_len(nrow(mlpd)) - 1L
  df <- reshape2::melt(mlpd, id.vars = "num_vars")
  ggplot(df, aes(x = num_vars, y = value, group = variable, color = variable)) +
    geom_line() +
    labs(color = "base") +
    ylab(ylabb) +
    geom_point() +
    theme(legend.position = "top")
}
paths <- sapply(sels, function(x) {
  x$search$path
})
plt1 <- plot_mlpd(sels, FALSE)
plt2 <- plot_mlpd(sels, TRUE)
plt_mlpd <- ggpubr::ggarrange(plt1, plt2)
plt_mlpd

Rerun projection

Show code
# Run projection again for problematic model
a <- fmf1
subm <- c(1, 5, 2)
options(bs.type = "hs")
p1 <- project_model(subm, a$fit, a$sd)
Projecting to submodel [1 5 2] 
Show code
options(bs.type = "tp")
p2 <- project_model(subm, a$fit, a$sd)
Projecting to submodel [1 5 2] 
Show code
projs <- list(hs = p1, tp = p2)

Visualizations of the problem

Identifying problematic draw/data point

Show code
x_test <- a$ds$dat$x[a$ds$split$test, ]
x_train <- a$ds$dat$x[a$ds$split$train, ]
train_ranges <- apply(x_train, 2, range)
# Place, for one projection `x`, the test covariates next to the row means of
# `x$all_lpd_test` (column "lpd") and the projected test means
# `x$pd$mu_proj_test`. Reads `x_test` from the enclosing environment.
get_mpt <- function(x) {
  cbind(
    x_test,
    lpd = rowMeans(x$all_lpd_test),
    p = x$pd$mu_proj_test
  )
}
mu_proj_test <- lapply(projs, get_mpt)

# Helper function: turn a wide matrix into a long data frame.
# Columns 1-9 of `mat` are kept as they are; columns 10 and onwards are
# stacked into a single value column, one block of rows per stacked column.
# Columns 9-12 of the result are renamed to mlpd, draw_idx, mu and data_idx.
to_df <- function(mat) {
  wide <- data.frame(mat)
  stacked_cols <- 10:ncol(wide)
  # reshape() takes the index from the digits in each stacked column's name,
  # so the index starts at 10 rather than 1 (presumably because data.frame()
  # auto-names the unnamed columns V10, V11, ... -- confirm).
  # this causes the draw indices to be off by 9
  long <- reshape(wide, varying = stacked_cols, sep = "", direction = "long")
  colnames(long)[9:12] <- c("mlpd", "draw_idx", "mu", "data_idx")
  long
}
# Convert each basis' wide matrix to long format (one data frame per basis).
df <- lapply(mu_proj_test, to_df)
# Row count of each long data frame, named like the elements of mu_proj_test.
nrows <- sapply(df, nrow)
# Stack the data frames; only the first two list elements are used.
df <- rbind(df[[1]], df[[2]])
# Repeat each basis name once per row it contributed, in stacking order.
b1 <- rep(names(nrows)[1], nrows[1])
b2 <- rep(names(nrows)[2], nrows[2])
df$base <- as.factor(c(b1, b2))
# Reset the row names to the default 1..nrow(df) numbering.
rownames(df) <- NULL

# Pick draw with most extreme mu (largest absolute value).
# which.max() returns a single row even when several rows tie for the
# maximum, so idx_extreme is always a scalar and the `==` filter below is
# never recycled against a longer vector.
idx_max <- which.max(abs(df$mu))
idx_extreme <- df$draw_idx[idx_max]
# Keep every row (all data points, all bases) belonging to that draw.
df_extreme <- df[df$draw_idx == idx_extreme, ]
# draw_idx is offset by 9 (see to_df), so subtract 9 to report the real index.
cat("Most problematic draw idx: ", idx_extreme - 9, "\n", sep = "")
Most problematic draw idx: 26
Show code
# Plot to identify problematic draw/data point
plt <- ggplot(df, aes(
  x = draw_idx, y = mu, group = base, color = mlpd,
  pch = base
)) +
  geom_point(alpha = 0.75) + xlab("Draw idx (off by 9)")
plt

Plotting the projected weights

Show code
get_mt <- function(x) {
  x$metrics$mlpd_test
}
v1 <- sapply(projs, get_mt) # should match earlier values
w <- p1$pd$weights
# Scatter plot of one column of the weight matrix `w`.
#
# Args:
#   w:   matrix of weights with row names; the part of each row name before
#        the first "." is used as the term that colours the point.
#   idx: index of the column of `w` to plot.
#
# Returns: a ggplot object (row name on the x-axis, weight on the y-axis).
create_plt_z <- function(w, idx) {
  df_w <- data.frame(w = w[, idx], x = rownames(w))
  # vapply() always yields a character vector, including for zero rows, where
  # sapply() would return an empty list that as.factor() cannot handle.
  term <- vapply(
    strsplit(df_w$x, "[.]"),
    function(parts) parts[1],
    character(1)
  )
  df_w$term <- as.factor(term)
  ggplot(df_w, aes(x = x, y = w, color = term)) +
    geom_point()
}
d_idx <- idx_extreme - 9 # indexing off by 9
ell_all <- extract(fmf1$fit, "ell")$ell
cat("Printing all lengthscale draws\n")
Printing all lengthscale draws
Show code
print(round(ell_all,2))
          
iterations [,1]  [,2] [,3]  [,4] [,5]  [,6] [,7] [,8]
      [1,] 0.70  1.32 2.21  0.20 1.34  0.25 9.38 2.22
      [2,] 0.56  1.89 0.41  0.88 0.83  0.42 1.72 4.65
      [3,] 0.88  1.44 0.73  2.69 0.81  2.13 0.49 4.27
      [4,] 1.65  0.92 0.67  0.44 1.78  0.12 7.02 1.47
      [5,] 0.80  1.14 8.79  0.76 0.94  1.10 5.22 0.96
      [6,] 0.73  0.80 0.89  0.30 0.96  2.79 3.06 0.48
      [7,] 0.67  1.14 3.06  1.35 1.19  6.82 1.37 1.28
      [8,] 0.84  0.61 3.15  0.39 0.69  0.72 1.40 0.89
      [9,] 0.77  1.16 1.91  7.15 0.46  0.16 0.34 1.20
     [10,] 0.65  0.35 1.81  1.63 0.45  1.24 0.35 3.85
     [11,] 1.37  1.28 5.21  0.89 0.61  0.92 0.92 2.67
     [12,] 1.82  1.10 0.24  0.41 0.57  0.52 2.48 0.77
     [13,] 0.87  1.10 1.91  0.08 0.62  2.06 1.82 0.69
     [14,] 0.78 10.40 0.16  0.58 0.59  2.43 4.27 0.98
     [15,] 0.93  1.66 1.30  1.88 1.41  1.38 1.06 1.76
     [16,] 0.60  0.94 0.51  6.41 0.76  1.09 0.57 1.18
     [17,] 0.44  2.21 1.13  0.21 0.40  1.41 0.39 1.91
     [18,] 1.08  1.30 5.39  2.77 0.79  2.91 0.33 2.91
     [19,] 1.90  0.76 8.56  1.36 1.14  0.44 1.65 1.52
     [20,] 0.96  1.13 2.61  1.83 0.96  2.79 6.89 0.35
     [21,] 1.36  0.95 2.13  0.78 0.81  0.90 0.62 3.22
     [22,] 1.72  0.53 0.39  1.98 0.33  3.31 0.77 0.54
     [23,] 1.84  1.34 0.96  0.55 0.80 20.32 0.73 0.75
     [24,] 1.42  0.97 0.95  0.44 1.40  0.53 4.12 2.41
     [25,] 1.35  1.01 0.34  0.48 2.02  3.33 0.67 0.65
     [26,] 1.31  0.87 1.97  0.59 1.19  1.26 2.21 2.16
     [27,] 0.47  1.63 7.01 12.13 1.31  2.09 1.73 7.83
     [28,] 1.03  2.29 0.75  1.69 0.98  6.13 0.20 1.74
     [29,] 0.45  0.79 3.33  1.53 0.95  1.29 1.41 0.51
     [30,] 0.77  1.22 0.31  0.30 1.36  3.64 0.07 2.30
Show code
ell_extreme <- extract(fmf1$fit, "ell")$ell[d_idx,]
alpha_extreme <- extract(fmf1$fit, "alpha")$alpha[d_idx,]
cat("Printing the reference model lengthscales in the most extreme draw\n")
Printing the reference model lengthscales in the most extreme draw
Show code
print(round(ell_extreme, 2))
[1] 1.31 0.87 1.97 0.59 1.19 1.26 2.21 2.16
Show code
cat("Printing the reference model magnitude params in the most extreme draw\n")
Printing the reference model magnitude params in the most extreme draw
Show code
print(round(alpha_extreme,2))
[1] 1.64 1.08 0.66 0.76 0.74 0.40 0.22 1.29
Show code
plt_z_a <- create_plt_z(w, d_idx) + ggtitle(paste0("Weights, draw_idx = ", d_idx)) +
  theme(legend.position = "top")
plt_z_b <- create_plt_z(w, 1) + ggtitle("Weights, draw_idx = 1") +
  theme(legend.position = "top")
plt_z <- ggpubr::ggarrange(plt_z_a, plt_z_b)
plt_z

mgcv convergence

Show code
cols <- c("sig2", "gcv.ubre.dev", "deviance", "boundary", "fully.converged", "iter", "score.calls")
info <- projs$hs$pd$mgcv_info[, cols]
info$mlpd <- colMeans(projs$hs$all_lpd_test)
print(info, n=nrow(info))
# A tibble: 30 × 8
      sig2 gcv.ubre.dev deviance boundary fully.converged  iter score.ca…¹  mlpd
     <dbl>        <dbl>    <dbl>    <dbl>           <dbl> <dbl>      <dbl> <dbl>
 1 0.0962       0.106     26.3          0               1     6          8 -2.55
 2 0.183        0.197     50.9          0               1     7          8 -2.58
 3 0.0581       0.0617    16.4          0               1     6          7 -2.55
 4 0.0553       0.0596    15.4          0               1     6          7 -2.55
 5 0.00989      0.0110     2.66         0               1     6          7 -2.54
 6 0.207        0.219     58.9          0               1     8         12 -2.56
 7 0.114        0.122     31.8          0               1     7          8 -2.54
 8 0.0335       0.0375     9.00         0               1     5          5 -2.56
 9 0.372        0.397    104.           0               1     7          8 -2.55
10 0.0354       0.0402     9.34         0               1     6          6 -2.54
11 0.134        0.144     37.7          0               1     8          9 -2.52
12 0.0263       0.0291     7.12         0               1     6          6 -2.58
13 0.0313       0.0350     8.39         0               1     5          5 -2.60
14 0.0242       0.0261     6.69         0               1    15         27 -2.56
15 0.0152       0.0166     4.15         0               1     5          7 -2.55
16 0.0812       0.0892    22.2          0               1     5          6 -2.53
17 0.160        0.174     43.8          0               1     7          7 -2.62
18 0.0102       0.0114     2.74         0               1     5          6 -2.56
19 0.139        0.147     39.3          0               1     7          8 -2.55
20 0.0390       0.0427    10.7          0               1     6          8 -2.54
21 0.00785      0.00876    2.11         0               1     6          6 -2.52
22 0.102        0.117     27.0          0               1    13         13 -2.55
23 0.00816      0.00903    2.21         0               1     4          6 -2.57
24 0.0182       0.0199     4.98         0               1     6          8 -2.56
25 0.0399       0.0439    10.9          0               1     9         10 -2.58
26 0.350        0.367    100.           0               1     8         10 -2.54
27 0.00310      0.00345    0.837        0               1    14         15 -2.56
28 0.0453       0.0499    12.3          0               1     6          7 -2.56
29 0.00563      0.00634    1.50         0               1    13         13 -2.59
30 0.0960       0.103     26.7          0               1     9         11 -2.57
# … with abbreviated variable name ¹​score.calls
Show code
plot(info$iter, info$mlpd)

Show code
plot(info$gcv.ubre.dev, info$mlpd)

Estimated smoothing parameters

Show code
sp <- projs$hs$pd$smoothing_params
print(sp)
          s(x1)       s(x5)        s(x2)
1  1.938497e+00  55.6764666 8.649611e-01
2  4.419989e+00   7.0041823 1.098680e+02
3  4.645350e+00  70.9626958 1.240838e+02
4  8.523210e+00  41.1683568 6.550375e+00
5  5.360737e+00   0.4914947 3.364001e+00
6  2.989235e+01 177.6235051 1.632473e+01
7  6.543220e+00  20.3985439 2.191127e+01
8  2.932785e+00   2.6387062 7.352246e-01
9  7.110747e+00  43.4238676 3.360341e+01
10 7.831289e-01   0.9934521 4.711414e-01
11 2.020841e+01   4.2682340 1.061311e+02
12 1.649227e+01   0.4292574 3.125159e+00
13 1.571155e+00   1.3853285 2.666179e+00
14 4.910036e-01   4.6633795 1.893168e+03
15 7.112003e+00   3.5101487 4.934997e+00
16 1.669106e+00   1.8668710 3.175538e+01
17 8.007559e-01   4.2873684 7.775241e+01
18 1.444625e+00   0.3060423 9.356137e+00
19 8.417770e+00  22.7607248 2.860944e+02
20 2.664828e+00   6.1941898 8.814355e+00
21 3.956993e-01   3.9659206 4.011270e+00
22 9.767915e-11   0.5522319 2.430648e+00
23 7.582025e+00   2.3646784 1.765834e+00
24 2.139890e+00   8.9856143 7.782869e+00
25 6.352352e+00  50.3718666 3.562629e-02
26 2.417612e+01  91.7965161 1.839197e+02
27 1.474844e-09  11.8271026 5.370825e+00
28 9.241905e+00   6.6035520 1.064805e+00
29 3.562564e-09   5.1934926 1.721493e+00
30 1.070164e+00 113.3286157 2.030760e+01
Show code
plot(info$mlpd, log10(sp$`s(x1)`), xlab="MLPD", ylab="log10(gamma_1)", main="Smoothing param for s(x1)")

The smoothing parameter is very small for the bad draw.

Other plots

Show code
LL <- a$sd$L # length 8
# Build a two-row figure for the rows of `df`: the top panel shows mu against
# the test point index, the bottom row shows mu against x1, x2 and x5.
#
# Args:
#   df: data frame with columns data_idx, mu, x1, x2, x5 and the columns
#       named by `cc` and `gg`.
#   cc: name of the column mapped to colour.
#   gg: name of the column mapped to group and point shape, or NULL.
#
# Reads `train_ranges` (one column per covariate) and `LL` from the enclosing
# environment.
# Returns: the arranged figure produced by ggpubr::ggarrange().
create_4_plots <- function(df, cc, gg) {
  # Shared aesthetic mapping; only the x variable differs between panels.
  make_aes <- function(xvar) {
    aes_string(x = xvar, y = "mu", group = gg, color = cc, pch = gg)
  }

  # Panel of mu against covariate number `j`, with the training range and the
  # box [-LL[j], LL[j]] drawn underneath the points.
  # The same index `j` selects the covariate, its training range and its box
  # size; previously the x5 panel took its training range from column 3.
  plot_vs_covariate <- function(j) {
    ggplot(df, make_aes(paste0("x", j))) +
      geom_vline(xintercept = train_ranges[, j], lty = 3) +
      geom_vline(xintercept = LL[j], lty = 2, col = "red") +
      geom_vline(xintercept = -LL[j], lty = 2, col = "red") +
      geom_point(alpha = 0.75) +
      theme(legend.position = "top")
  }

  plt_ex <- ggplot(df, make_aes("data_idx")) +
    geom_point(alpha = 0.75) +
    xlab("test_point_idx")

  plt_ex_x1 <- plot_vs_covariate(1)
  plt_ex_x2 <- plot_vs_covariate(2)
  plt_ex_x5 <- plot_vs_covariate(5)

  message("The red dashed lines show the box size [-L, L].")
  message("The black dotted lines show the training data range.")
  pp <- ggpubr::ggarrange(plt_ex_x1, plt_ex_x2, plt_ex_x5, nrow = 1, ncol = 3)
  ggpubr::ggarrange(plt_ex, pp,
    nrow = 2, ncol = 1
  )
}

plt_a <- create_4_plots(df_extreme, "base", "base")
Warning: `aes_string()` was deprecated in ggplot2 3.0.0.
ℹ Please use tidy evaluation ideoms with `aes()`
The red dashed lines show the box size [-L, L].
The black dotted lines show the training data range.
Show code
df_extreme_hs <- df_extreme[which(df_extreme$base == "hs"), ]
df_extreme_tp <- df_extreme[which(df_extreme$base == "tp"), ]
plt_b <- create_4_plots(df_extreme_hs, "mlpd", NULL)
The red dashed lines show the box size [-L, L].
The black dotted lines show the training data range.
Show code
plt_c <- ggplot(df_extreme_hs, aes(x = x1, y = x2, color = mlpd)) +
  geom_point() +
  geom_vline(xintercept = train_ranges[, 1], lty = 3) +
  geom_hline(yintercept = train_ranges[, 2], lty = 3) +
  ggtitle("MLPD (using HS basis) at test points",
    subtitle = "Dotted lines = training data range"
  )

plt_a

Show code
plt_b

Show code
plt_c

Comparing the predictions (also against the test data)

Problematic draw

The predictions obtained with the thin plate spline basis and with the HS GP basis match very well at all test points except the edge cases.

Show code
ytest <- fmf1$sd$y_test
mu_hs <- df_extreme_hs$mu
mu_tp <- df_extreme_tp$mu
plot(mu_tp, mu_hs, main = "All")

Show code
plot(mu_tp, mu_hs, ylim=c(-6,6), main="No edge cases")

Show code
plot(mu_tp, ytest)

A different draw

Show code
df_other <- df[df$draw_idx == 10, ]
df_other_hs <- df_other[which(df_other$base == "hs"), ]
df_other_tp <- df_other[which(df_other$base == "tp"), ]
mu_hs <- df_other_hs$mu
mu_tp <- df_other_tp$mu
plot(mu_tp, mu_hs, main = "All")

Show code
plot(mu_tp, ytest)

Plotting the additive components

Reference model

Show code
plot_comps(fmf1$fit, fmf1$sd, ncol = 4, nrow = 2)

Show code
plot_comps(fmf1$fit, fmf1$sd, test = TRUE, ncol = 4, nrow = 2)

Submodel

Show code
pp <- plot_projection(subm, projs$hs, fmf1$sd)
pp[[1]] + ggtitle("Projected component f1")

Show code
pp[[2]] + ggtitle("Projected component f5")

Show code
pp[[3]] + ggtitle("Projected component f2")

Get ref model mean prediction

Show code
# Get the reference model draws via get_draws() and take their `f` element.
draws_ref <- get_draws(fmf1$fit)
mu_ref <- draws_ref$f
# Select the row of the most extreme draw; draw_idx values are offset by 9
# (see to_df), hence the subtraction.
mu <- mu_ref[idx_extreme-9,]
# NOTE: this overwrites the earlier x_train with the transpose of X_train.
x_train <- t(fmf1$sd$X_train)
# Combine the training inputs and mu column-wise into one data frame.
df_save <- data.frame(cbind(x_train, mu)) # saved this